from torch import nn
from src.nn import (
    SelfAttentionBlock,
    FFN,
    DropPath,
    LayerNorm,
    INDEX_BASED_NORMS)

from src.utils import VersionHolder
__all__ = ['TransformerBlock']


# TODO: Careful with how we define the index for LayerNorm:
#  cluster-wise or cloud-wise ? Maybe cloud-wise, seems more stable...


class TransformerBlock(nn.Module):
    """Base block of the Transformer architecture:

        x ---------------- + ---------------- + -->
            \             |   \              |
             -- N -- SA --     -- N -- FFN --

    Where:
        - N: Normalization
        - SA: Self-Attention
        - FFN: Feed-Forward Network

    Inspired by: https://github.com/microsoft/Swin-Transformer

    :param dim: int
        Dimension of the feature space on which the transformer
        block operates
    :param num_heads: int
        Number of attention heads
    :param qkv_bias: bool
        Whether the linear layers producing queries, keys, and
        values should have a bias
    :param qk_dim: int
        Dimension of the queries and keys
    :param qk_scale: str
        Scaling applied to the query*key product before the softmax.
        More specifically, one may want to normalize the query-key
        compatibilities based on the number of dimensions (referred
        to as 'd' here) as in a vanilla Transformer implementation,
        or based on the number of neighbors each node has in the
        attention graph (referred to as 'g' here). If nothing is
        specified the scaling will be `1 / (sqrt(d) * sqrt(g))`,
        which is equivalent to passing `'d.g'`. Passing `'d+g'` will
        yield `1 / (sqrt(d) + sqrt(g))`. Meanwhile, passing `'d'` will
        yield `1 / sqrt(d)`, and passing `'g'` will yield
        `1 / sqrt(g)`. A worked numerical example is given in a comment
        right after this docstring
    :param in_rpe_dim: int
        Dimension of the features passed as input for relative
        positional encoding computation (i.e. edge features)
    :param ffn_ratio: int
        Multiplicative factor for computing the dimension of the
        `FFN` inverted bottleneck: `ffn_ratio * dim`
    :param attn_drop: float
        Dropout on the attention weights of the `SelfAttentionBlock`
    :param residual_drop: float
        Dropout on the output features of the `SelfAttentionBlock` and
        of the `FFN`
    :param drop_path: float
        Dropout on the `SelfAttentionBlock` and `FFN` residual paths.
        Contrary to other dropout parameters, here we either keep or
        drop all the features of a path at once. This allows training
        with stochastic depth
    :param activation: nn.Module
        Activation function for the `FFN` module
    :param norm: nn.Module
        Normalization module for the `SelfAttentionBlock` and `FFN`
        residual branches
    :param pre_norm: bool
        Whether the normalization should be applied before or after
        the `SelfAttentionBlock` and `FFN` in the residual branches
    :param no_sa: bool
        Whether the self-attention residual branch should be removed
        altogether
    :param no_ffn: bool
        Whether the feed-forward residual branch should be removed
        altogether
    :param k_rpe: bool
        Whether keys should receive relative positional encodings
        computed from edge features. See `SelfAttentionBlock`
    :param q_rpe: bool
        Whether queries should receive relative positional encodings
        computed from edge features. See `SelfAttentionBlock`
    :param v_rpe: bool
        Whether values should receive relative positional encodings
        computed from edge features. See `SelfAttentionBlock`
    :param k_delta_rpe: bool
        Whether keys should receive relative positional encodings
        computed from the difference between source and target node
        features. See `SelfAttentionBlock`
    :param q_delta_rpe: bool
        Whether queries should receive relative positional encodings
        computed from the difference between source and target node
        features. See `SelfAttentionBlock`
    :param qk_share_rpe: bool
        Whether queries and keys should use the same parameters for
        building relative positional encodings. See
        `SelfAttentionBlock`
    :param q_on_minus_rpe: bool
        Whether relative positional encodings for queries should be
        computed on the opposite of the features used for keys. This
        allows, for instance, breaking the symmetry when `qk_share_rpe`
        is True but we want relative positional encodings to capture
        different meanings for keys and queries. See
        `SelfAttentionBlock`
    :param heads_share_rpe: bool
        Whether attention heads should share the same parameters for
        building relative positional encodings. See
        `SelfAttentionBlock`
    :param version_holder: VersionHolder
        Object storing the code version. Used to determine how the
        residual connection added after the `FFN` is defined. This is
        implemented to ensure compatibility with the official weights
        of SPT and SPC released before the commit fixing the residual
        connection:
        https://github.com/drprojects/superpoint_transformer/commit/a0f753b35b86e06d426113bdeac9b0123b220aa3

        If the version is greater than or equal to 2.2.0, the residual
        connection uses the input of the `FFN` (i.e. the output of the
        self-attention branch). Otherwise, it uses the input of the
        self-attention block.
    """

    def __init__(
            self,
            dim,
            num_heads=1,
            qkv_bias=True,
            qk_dim=8,
            qk_scale=None,
            in_rpe_dim=18,
            ffn_ratio=4,
            attn_drop=None,
            residual_drop=None,
            drop_path=None,
            activation=nn.LeakyReLU(),
            norm=LayerNorm,
            pre_norm=True,
            no_sa=False,
            no_ffn=False,
            k_rpe=False,
            q_rpe=False,
            v_rpe=False,
            k_delta_rpe=False,
            q_delta_rpe=False,
            qk_share_rpe=False,
            q_on_minus_rpe=False,
            heads_share_rpe=False,
            version_holder=VersionHolder()):
        super().__init__()

        self.dim = dim
        self.pre_norm = pre_norm
        self.version_holder = version_holder

        # Self-Attention residual branch
        self.no_sa = no_sa
        if not no_sa:
            self.sa_norm = norm(dim)
            self.sa = SelfAttentionBlock(
                dim,
                num_heads=num_heads,
                in_dim=None,
                out_dim=dim,
                qkv_bias=qkv_bias,
                qk_dim=qk_dim,
                qk_scale=qk_scale,
                in_rpe_dim=in_rpe_dim,
                attn_drop=attn_drop,
                drop=residual_drop,
                k_rpe=k_rpe,
                q_rpe=q_rpe,
                v_rpe=v_rpe,
                k_delta_rpe=k_delta_rpe,
                q_delta_rpe=q_delta_rpe,
                qk_share_rpe=qk_share_rpe,
                q_on_minus_rpe=q_on_minus_rpe,
                heads_share_rpe=heads_share_rpe)

        # Feed-Forward Network residual branch
        self.no_ffn = no_ffn
        if not no_ffn:
            self.ffn_norm = norm(dim)
            self.ffn_ratio = ffn_ratio
            self.ffn = FFN(
                dim,
                hidden_dim=int(dim * ffn_ratio),
                activation=activation,
                drop=residual_drop)

        # Optional DropPath module for stochastic depth
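        # (Assumed behavior of `DropPath`, in line with the usual
        # stochastic-depth formulation: during training, the whole
        # residual branch output is zeroed with probability `drop_path`
        # and surviving branches are rescaled accordingly)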
        self.drop_path = DropPath(drop_path) \
            if drop_path is not None and drop_path > 0 else nn.Identity()

    def forward(self, 
                x, 
                norm_index, 
                edge_index=None, 
                edge_attr=None):
        """
        :param x: FloatTensor of shape (N, C)
            Node features
        :param norm_index: LongTensor of shape (N)
            Node indices for the LayerNorm
        :param edge_index: LongTensor of shape (2, E)
            Edges in torch_geometric [[sources], [targets]] format for
            the self-attention module
        :param edge_attr: FloatTensor of shape (E, F)
            Edge attributes in torch_geometric format for relative
            positional encoding in the self-attention module
        :return: Tuple (x, norm_index, edge_index)

        A minimal usage sketch is given at the end of this module.
        """
        assert x.dim() == 2, 'x should be a 2D Tensor'
        assert x.is_floating_point(), 'x should be a 2D FloatTensor'
        assert norm_index.dim() == 1 and norm_index.shape[0] == x.shape[0], \
            'norm_index should be a 1D LongTensor'
        assert edge_index is None or \
               (edge_index.dim() == 2 and not edge_index.is_floating_point()), \
            'edge_index should be a 2D LongTensor'
        assert edge_attr is None or \
               (edge_attr.dim() == 2 and edge_attr.shape[0] == edge_index.shape[1]),\
            'edge_attr should be a 2D FloatTensor'

        # Keep track of x for the residual connection
        shortcut = x

        # Self-Attention residual branch. Skip the SA block if no edges
        # are provided
        if self.no_sa or edge_index is None or edge_index.shape[1] == 0:
            pass
        elif self.pre_norm:
            x = self._forward_norm(self.sa_norm, x, norm_index)
            x = self.sa(x, edge_index, edge_attr=edge_attr)
            x = shortcut + self.drop_path(x)
        else:
            x = self.sa(x, edge_index, edge_attr=edge_attr)
            x = self.drop_path(x)
            x = self._forward_norm(self.sa_norm, shortcut + x, norm_index)
            
        # From version 2.2.0 onward, the FFN residual connection uses
        # the output of the self-attention branch, so update the
        # shortcut accordingly
        if (self.version_holder.major >= 3
            or (self.version_holder.major == 2 and 
                self.version_holder.minor >= 2)):
            shortcut = x

        # Feed-Forward Network residual branch
        if not self.no_ffn and self.pre_norm:
            x = self._forward_norm(self.ffn_norm, x, norm_index)
            x = self.ffn(x)
            x = shortcut + self.drop_path(x)
        if not self.no_ffn and not self.pre_norm:
            x = self.ffn(x)
            x = self.drop_path(x)
            x = self._forward_norm(self.ffn_norm, shortcut + x, norm_index)

        return x, norm_index, edge_index

    @staticmethod
    def _forward_norm(norm, x, norm_index):
        """Simple helper for the forward pass on norm modules. Some
        modules require an index, while others don't.
        """
        if isinstance(norm, INDEX_BASED_NORMS):
            return norm(x, batch=norm_index)
        return norm(x)
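

# A minimal usage sketch of `TransformerBlock`, following the shapes
# documented in `forward()`. The sizes below (100 nodes, 500 edges,
# 64 feature dimensions, 18 edge-feature dimensions) are illustrative
# assumptions, not taken from any configuration of this repository.
if __name__ == '__main__':
    import torch

    block = TransformerBlock(dim=64)

    x = torch.rand(100, 64)                          # (N, C) node features
    norm_index = torch.zeros(100, dtype=torch.long)  # all nodes in one cloud
    edge_index = torch.randint(0, 100, (2, 500))     # (2, E) attention graph
    edge_attr = torch.rand(500, 18)                  # (E, F) edge features

    x, norm_index, edge_index = block(
        x, norm_index, edge_index=edge_index, edge_attr=edge_attr)
    print(x.shape)  # expected: torch.Size([100, 64])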
